import torch.nn as nn
import torch
import torch.nn.functional as F
from distributions import DiscreteLogisticMixture, Beta, Gaussian, ImageBitwiseCategorical, ImageCategorical


def softclip(tensor, min):
    """Softly bound ``tensor`` from below by ``min``.

    Evaluates ``min + softplus(tensor - min)``: a smooth, differentiable
    alternative to a hard clamp whose result never falls below ``min``.
    Taken from Handful of Trials.

    Note: the parameter name ``min`` shadows the builtin; it is kept so
    keyword callers (``min=...``) keep working.
    """
    return F.softplus(tensor - min) + min


class HalfSigmoid(nn.Module):
    """Squash the first half of dimension 1 with a sigmoid; pass the rest through.

    The input is split along dim 1 at ``x.shape[1] // 2``. The leading part goes
    through ``torch.sigmoid``. The trailing part is returned unchanged, and when
    the size is odd it receives the extra entry. The two parts are
    concatenated back along dim 1, so the output shape equals the input shape.
    """

    def forward(self, x):
        # Integer floor division: same split as int(n / 2) for sizes, without
        # the float round-trip.
        half = x.shape[1] // 2
        # torch.sigmoid replaces the deprecated torch.nn.functional.sigmoid.
        squashed = torch.sigmoid(x[:, :half])
        return torch.cat([squashed, x[:, half:]], dim=1)


class OptimalVarianceGaussian(Gaussian):
    """Gaussian whose NLL uses the best-fitting sigma for each leading-dim slice.

    Technically not a distribution: ``nll`` re-estimates its own standard
    deviation from the datum at hand. For each index along dim 0, sigma is the
    root-mean-square residual ``x - mu`` over all remaining dims. It then
    scores ``x`` under ``Gaussian(mu=self.mu, sigma=sigma)``.
    """

    def nll(self, x):
        # Every dimension except the first (presumably the batch dim -- TODO confirm).
        reduce_dims = list(range(1, len(x.shape)))
        squared_residual = (x - self.mu).pow(2)
        # keepdim=True keeps sigma broadcastable against mu / x.
        # NOTE(review): sigma is exactly 0 if x == mu on a whole slice.
        sigma = torch.sqrt(squared_residual.mean(reduce_dims, keepdim=True))
        fitted = Gaussian(mu=self.mu, sigma=sigma)
        return fitted.nll(x)


class OptimalScalarVarianceGaussian(Gaussian):
    """Gaussian whose NLL uses a single best-fitting sigma for the whole tensor.

    Technically not a distribution: ``nll`` re-estimates its own standard
    deviation from the datum at hand. Sigma is the root-mean-square residual
    ``x - mu`` over every dimension of ``x``, giving one scalar, kept with
    singleton dims. It then scores ``x`` under ``Gaussian(mu=self.mu, sigma=sigma)``.
    """

    def nll(self, x):
        # Reduce over all dimensions, including the leading one.
        all_dims = list(range(len(x.shape)))
        mean_squared_error = (x - self.mu).pow(2).mean(all_dims, keepdim=True)
        # NOTE(review): sigma is exactly 0 if x == mu everywhere.
        sigma = torch.sqrt(mean_squared_error)
        fitted = Gaussian(mu=self.mu, sigma=sigma)
        return fitted.nll(x)


class Addition(nn.Module):
    """Module that adds a constant offset ``a`` to its input.

    Useful inside ``nn.Sequential`` to shift the output range of a preceding
    activation (e.g. ``Addition(1)`` after ``Softplus``).
    """

    def __init__(self, a):
        super().__init__()
        # Offset added to every forward input.
        self.a = a

    def forward(self, x):
        shifted = x + self.a
        return shifted

    
def get_distribution_class(distribution, sigma_mode=None):
    """Map a distribution name to ``(distribution class, n_param, activation)``.

    Args:
        distribution: One of ``'gaussian'``, ``'beta'``, ``'categorical'``,
            ``'bitwise_categorical'``, ``'bernoulli'``,
            ``'discrete_logistic_mixture'``, ``'optimal_gaussian'`` or
            ``'optimal_constant_gaussian'``.
        sigma_mode: Optional override. ``'optimal'`` selects
            ``OptimalVarianceGaussian`` and ``'optimal_constant'`` selects
            ``OptimalScalarVarianceGaussian``, whatever ``distribution`` is.

    Returns:
        A tuple ``(distr, n_param, activation)``:

        - ``distr`` is the distribution class, or ``None`` for ``'bernoulli'``.
        - ``n_param`` is the parameter count tied to the distribution.
          Presumably this is the number of values the decoder emits per
          output element (confirm against callers).
        - ``activation`` is an ``nn.Module`` to apply to the decoder output,
          or ``None`` for the categorical distributions.

        The "optimal" overrides only replace ``distr``. ``n_param`` and
        ``activation`` keep whatever the ``distribution`` branch set, which
        defaults to 1 and ``nn.Sigmoid()``.

    Raises:
        ValueError: If ``distribution`` is not recognised and no
            ``sigma_mode`` override selects a class.
    """
    # Sentinel rather than None: None is a legitimate result for 'bernoulli'.
    unset = object()
    distr = unset
    activation = nn.Sigmoid()
    n_param = 1
    if distribution == 'gaussian':
        distr = Gaussian
    elif distribution == 'beta':
        distr = Beta
        n_param = 2
        # Beta concentration parameters must exceed 1 here: softplus(.) + 1.
        activation = nn.Sequential(nn.Softplus(), Addition(1))
    elif distribution == 'categorical':
        distr = ImageCategorical
        n_param = 256
        activation = None
    elif distribution == 'bitwise_categorical':
        distr = ImageBitwiseCategorical
        n_param = 8
        activation = None
    elif distribution == 'bernoulli':
        distr = None
    elif distribution == 'discrete_logistic_mixture':
        distr = DiscreteLogisticMixture
        n_param = 5
    # Overrides applied after the table above; the later one wins if both match.
    if sigma_mode == 'optimal' or distribution == 'optimal_gaussian':
        distr = OptimalVarianceGaussian
    if sigma_mode == 'optimal_constant' or distribution == 'optimal_constant_gaussian':
        distr = OptimalScalarVarianceGaussian

    if distr is unset:
        # Fail with a clear message instead of an UnboundLocalError on `distr`.
        raise ValueError('Unknown distribution: {!r}'.format(distribution))
    return distr, n_param, activation
